/*
 * Copyright © 2019, 2021 Apple Inc. and the ServiceTalk project authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.servicetalk.grpc.protoc;

import com.google.protobuf.DescriptorProtos;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.DescriptorProtos.EnumDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileOptions;
import com.google.protobuf.DescriptorProtos.MethodDescriptorProto;
import com.google.protobuf.DescriptorProtos.ServiceDescriptorProto;
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse;
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorResponse.File;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.TypeSpec;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;

import static com.squareup.javapoet.MethodSpec.constructorBuilder;
import static io.servicetalk.grpc.protoc.StringUtils.isNotNullNorEmpty;
import static io.servicetalk.grpc.protoc.StringUtils.sanitizeIdentifier;
import static javax.lang.model.element.Modifier.FINAL;
import static javax.lang.model.element.Modifier.PRIVATE;
import static javax.lang.model.element.Modifier.PUBLIC;
import static javax.lang.model.element.Modifier.STATIC;

/**
 * A single proto file for which we will be generating classes.
 * <p>
 * Computes the Java package and outer class name for the file, tracks the service class builders created for it,
 * and writes the generated service code into a {@link CodeGeneratorResponse}.
 */
final class FileDescriptor implements GenerationContext {
    private static final String GENERATED_BY_COMMENT = "Generated by ServiceTalk gRPC protoc plugin";
    /**
     * Inferred behavior from protobuf-java is that if no suffix is explicitly provided the root file name will have
     * this suffix. See
     * <a href="https://github.com/protocolbuffers/protobuf/blob/v3.19.4/src/google/protobuf/compiler/java/java_name_resolver.cc#L49">java_name_resolver.cc</a>
     */
    private static final String OUTER_CLASS_SUFFIX = "OuterClass";
    private final FileDescriptorProto protoFile;
    @Nullable
    private final String protoPackageName;
    private final boolean deprecated;
    private final boolean multipleClassFiles;
    private final String javaPackageName;
    private final String outerClassName;
    /**
     * Java scope that directly encloses the top-level message classes of this file: the Java package when
     * {@code java_multiple_files} is set, otherwise the package followed by the outer class name.
     */
    private final String javaOuterScope;
    @Nullable
    private final String typeNameSuffix;
    private final List<TypeSpec.Builder> serviceClassBuilders;
    /**
     * Simple Java type names already handed out (or reserved, like the outer class name) for this file.
     */
    private final Set<String> reservedJavaTypeName = new HashSet<>();
    /**
     * Maps each builder created by {@link #newServiceClassBuilder(ServiceDescriptorProto)} back to its service.
     * {@link TypeSpec.Builder} does not override {@code equals}/{@code hashCode}, so lookups are by identity.
     */
    private final Map<TypeSpec.Builder, ServiceDescriptorProto> protoForServiceBuilder;

    /**
     * A single proto file for which we will be generating classes.
     *
     * @param protoFile The file
     * @param typeNameSuffix optional suffix to be appended to service class names.
     */
    FileDescriptor(final FileDescriptorProto protoFile,
                   @Nullable final String typeNameSuffix) {
        this.protoFile = protoFile;
        final String sanitizedProtoFileName = sanitizeFileName(protoFile.getName());
        protoPackageName = protoFile.hasPackage() ? protoFile.getPackage() : null;
        this.typeNameSuffix = typeNameSuffix;

        if (protoFile.hasOptions()) {
            final FileOptions fileOptions = protoFile.getOptions();
            deprecated = fileOptions.hasDeprecated() && fileOptions.getDeprecated();
            multipleClassFiles = fileOptions.hasJavaMultipleFiles() && fileOptions.getJavaMultipleFiles();
            javaPackageName = fileOptions.hasJavaPackage() ? fileOptions.getJavaPackage() :
                    inferJavaPackageName(protoPackageName, sanitizedProtoFileName);
            outerClassName = fileOptions.hasJavaOuterClassname() ?
                    sanitizeClassName(fileOptions.getJavaOuterClassname()) :
                    inferOuterClassName(sanitizedProtoFileName, protoFile);
        } else {
            deprecated = false;
            multipleClassFiles = false;
            javaPackageName = inferJavaPackageName(protoPackageName, sanitizedProtoFileName);
            outerClassName = inferOuterClassName(sanitizedProtoFileName, protoFile);
        }
        javaOuterScope = multipleClassFiles ? javaPackageName() : javaPackageName() + '.' + outerJavaClassName();
        // The outer class name is taken, so generated service classes must not reuse it.
        reservedJavaTypeName.add(outerClassName);

        serviceClassBuilders = new ArrayList<>(protoFile.getServiceCount());
        protoForServiceBuilder = new HashMap<>();
    }

    /**
     * @return the file name as reported by protoc for this file.
     */
    String protoFileName() {
        return protoFile.getName();
    }

    /**
     * @return the proto {@code package} of this file, or {@code null} if none was declared.
     */
    @Nullable
    String getProtoPackageName() {
        return protoPackageName;
    }

    /**
     * @return the services declared in this file.
     */
    List<ServiceDescriptorProto> protoServices() {
        return protoFile.getServiceList();
    }

    /**
     * Builds a map from fully-qualified protobuf message type names (with a leading {@code '.'}) to the Java
     * {@link ClassName} of every message declared in this file, including nested messages.
     *
     * @return a new, mutable map of proto type name to Java class name.
     */
    Map<String, ClassName> messageTypesMap() {
        final Map<String, ClassName> messageTypesMap = new HashMap<>(protoFile.getMessageTypeCount());
        // An empty package contributes no scope (consistent with methodPath); otherwise "..Foo" keys would result.
        final String protoScope = isNotNullNorEmpty(protoPackageName) ? '.' + protoPackageName : null;
        addMessageTypes(protoFile.getMessageTypeList(), protoScope, javaOuterScope, messageTypesMap);
        return messageTypesMap;
    }

    /**
     * @return the source code info (locations and comments) attached to this file by protoc.
     */
    DescriptorProtos.SourceCodeInfo sourceCodeInfo() {
        return protoFile.getSourceCodeInfo();
    }

    /**
     * Recursively records every message in {@code messageTypes} and its nested messages.
     *
     * @param messageTypes messages declared directly in the current scope.
     * @param parentProtoScope fully-qualified proto scope (with a leading {@code '.'}), or {@code null} for the root.
     * @param parentJavaScope Java scope (package or enclosing class) that contains these messages.
     * @param messageTypesMap map to populate.
     */
    private static void addMessageTypes(final List<DescriptorProto> messageTypes,
                                        @Nullable final String parentProtoScope,
                                        final String parentJavaScope,
                                        final Map<String, ClassName> messageTypesMap) {
        messageTypes.forEach(t -> {
            final String protoTypeName = parentProtoScope != null ?
                    (parentProtoScope + '.' + t.getName()) : '.' + t.getName();
            final ClassName className = ClassName.get(parentJavaScope, t.getName());
            messageTypesMap.put(protoTypeName, className);

            addMessageTypes(t.getNestedTypeList(), protoTypeName, className.canonicalName(), messageTypesMap);
        });
    }

    /**
     * Reserves {@code name} if it is still free. Otherwise it reserves {@code name} with the smallest non-negative
     * integer suffix that is not reserved yet ({@code name0}, {@code name1}, ...).
     *
     * @param name the desired simple type name.
     * @return the reserved, unique simple type name.
     */
    @Override
    public String deconflictJavaTypeName(final String name) {
        if (reservedJavaTypeName.add(name)) {
            return name;
        }

        int i = 0;
        String uniqueName;
        do {
            uniqueName = name + i;
            i++;
        } while (!reservedJavaTypeName.add(uniqueName));

        return uniqueName;
    }

    /**
     * Computes the fully-qualified Java name of a nested type inside {@code enclosingClassName}. The enclosing class
     * itself lives directly in this file's Java outer scope.
     *
     * @param enclosingClassName simple name of the class that encloses the new type.
     * @param name desired simple name of the new type; it is sanitized and deconflicted.
     * @return the fully-qualified Java name of the nested type.
     */
    @Override
    public String deconflictJavaTypeName(final String enclosingClassName, final String name) {
        return javaOuterScope + '.' + enclosingClassName + '.' + deconflictJavaTypeName(sanitizeClassName(name));
    }

    /**
     * Creates a {@code public final} class with a private no-op constructor for {@code serviceProto}. The class name
     * is the service name plus the optional type name suffix, sanitized and deconflicted. The builder is remembered
     * so that {@link #writeTo(CodeGeneratorResponse.Builder)} can emit it later.
     *
     * @param serviceProto the service to generate a class for.
     * @return a {@link ServiceClassBuilder} wrapping the new builder and its class name.
     */
    @Override
    public ServiceClassBuilder newServiceClassBuilder(final ServiceDescriptorProto serviceProto) {
        final String rawClassName = typeNameSuffix == null ? serviceProto.getName() :
                serviceProto.getName() + typeNameSuffix;
        final String className = deconflictJavaTypeName(sanitizeClassName(rawClassName));

        final TypeSpec.Builder builder = TypeSpec.classBuilder(className)
                .addModifiers(PUBLIC, FINAL)
                .addMethod(constructorBuilder()
                        .addModifiers(PRIVATE)
                        .addComment("no instances")
                        .build());

        // Mark deprecated if either the whole proto or the individual service is deprecated
        if (deprecated || (serviceProto.hasOptions() && serviceProto.getOptions().hasDeprecated() &&
                serviceProto.getOptions().getDeprecated())) {
            builder.addAnnotation(Deprecated.class);
        }

        serviceClassBuilders.add(builder);
        protoForServiceBuilder.put(builder, serviceProto);
        return new ServiceClassBuilder(builder, className);
    }

    /**
     * Builds the gRPC method path: {@code /<package>.<Service>/<Method>}, or {@code /<Service>/<Method>} when the
     * file has no (or an empty) proto package.
     */
    @Override
    public String methodPath(final ServiceDescriptorProto serviceProto, final MethodDescriptorProto methodProto) {
        final StringBuilder sb = new StringBuilder(128).append('/');
        if (isNotNullNorEmpty(protoPackageName)) {
            sb.append(protoPackageName).append('.');
        }
        sb.append(serviceProto.getName()).append('/').append(methodProto.getName());
        return sb.toString();
    }

    /**
     * Adds the generated service classes to {@code responseBuilder}. It does nothing if no service class builder was
     * created for this file.
     *
     * @param responseBuilder the response to add generated files (or file insertions) to.
     */
    void writeTo(final CodeGeneratorResponse.Builder responseBuilder) {
        if (serviceClassBuilders.isEmpty()) {
            return;
        }

        if (multipleClassFiles) {
            writeSeparateServiceFiles(responseBuilder);
        } else {
            writeIntoOuterClass(responseBuilder);
        }
    }

    /**
     * Inserts every service class, as a static nested class, into the outer class file that protoc generates.
     * This uses the {@code outer_class_scope} insertion point.
     */
    private void writeIntoOuterClass(final CodeGeneratorResponse.Builder responseBuilder) {
        // All source code should be put into 1 file, use the file that is generated by protoc,
        // which is done by writing to a .java file whose name is calculated to match the one that protoc
        // will create (i.e. this file name is not provided in CodeGeneratorRequest)
        final String fileName = calculateFileName(javaPackageName(), outerJavaClassName());

        insertSingleFileContent("// " + GENERATED_BY_COMMENT, fileName, responseBuilder);
        for (final TypeSpec.Builder builder : serviceClassBuilders) {
            final String content = addInsertionPoint(
                    builder.addModifiers(STATIC).build().toString(),
                    protoForServiceBuilder.get(builder).getName());
            insertSingleFileContent(content, fileName, responseBuilder);
        }
    }

    /**
     * Writes each service class to its own top-level {@code .java} file in the Java package.
     */
    private void writeSeparateServiceFiles(final CodeGeneratorResponse.Builder responseBuilder) {
        final String packageName = javaPackageName();
        for (final TypeSpec.Builder builder : serviceClassBuilders) {
            final TypeSpec serviceType = builder.build();
            final ServiceDescriptorProto serviceDescriptorProto = protoForServiceBuilder.get(builder);
            final File.Builder fileBuilder = File.newBuilder();
            fileBuilder.setName(calculateFileName(packageName, serviceType.name));

            final JavaFile javaFile = JavaFile.builder(packageName, serviceType)
                    .indent("    ")
                    .addFileComment(GENERATED_BY_COMMENT)
                    .build();

            fileBuilder.setContent(addInsertionPoint(javaFile.toString(), serviceDescriptorProto.getName()));
            responseBuilder.addFile(fileBuilder.build());
        }
    }

    /**
     * Replaces every empty placeholder class named {@code __<fqn>} in {@code content} with the protoc insertion-point
     * comment for {@code service_scope:<fqn>}. The placeholder is presumably added by the code generator; it is not
     * created in this file.
     *
     * @param content generated Java source.
     * @param name the proto service name.
     * @return {@code content} with placeholders replaced.
     */
    private String addInsertionPoint(final String content, final String name) {
        final String fqn = protoPackageName != null ? protoPackageName + '.' + name : name;
        // Quote the name so '.' in the package matches only a literal dot, and quote the replacement so it is
        // inserted verbatim.
        return content.replaceAll("class __" + Pattern.quote(fqn) + " \\{\n *}",
                Matcher.quoteReplacement(insertionPoint(fqn)));
    }

    /**
     * @param fqn fully-qualified proto service name.
     * @return the protoc insertion-point comment for the given service scope.
     */
    static String insertionPoint(final String fqn) {
        return "// @@protoc_insertion_point(service_scope:" + fqn + ')';
    }

    /**
     * Adds a response entry that inserts {@code content} (plus a trailing newline) at the {@code outer_class_scope}
     * insertion point of {@code fileName}.
     */
    private static void insertSingleFileContent(final String content, final String fileName,
                                                final CodeGeneratorResponse.Builder responseBuilder) {
        final File.Builder fileBuilder = File.newBuilder();
        fileBuilder.setName(fileName);
        fileBuilder.setInsertionPoint("outer_class_scope");
        fileBuilder.setContent(content + "\n");
        responseBuilder.addFile(fileBuilder.build());
    }

    private String outerJavaClassName() {
        return outerClassName;
    }

    private String javaPackageName() {
        return javaPackageName;
    }

    /**
     * Derives the outer class name from the file's base name. {@link #OUTER_CLASS_SUFFIX} is appended when the
     * derived name clashes with a type declared in the file.
     */
    private static String inferOuterClassName(final String sanitizedProtoFileName,
                                              final FileDescriptorProto protoFile) {
        final String className = sanitizeClassName(sanitizedProtoFileName);
        return hasConflictingClassName(className, protoFile) ? className + OUTER_CLASS_SUFFIX : className;
    }

    /**
     * See <a href="https://github.com/protocolbuffers/protobuf/blob/v3.19.4/src/google/protobuf/compiler/java/java_name_resolver.cc#L192-L214">java_name_resolver.cc</a>.
     * @param sanitizedClassName The sanitized classname to check for conflicts.
     * @param protoFile The {@link FileDescriptorProto} to search for conflicting names in.
     * @return {@code true} if there is a name conflict, {@code false} otherwise.
     */
    private static boolean hasConflictingClassName(String sanitizedClassName, FileDescriptorProto protoFile) {
        for (EnumDescriptorProto enumDescriptor : protoFile.getEnumTypeList()) {
            if (enumDescriptor.getName().equals(sanitizedClassName)) {
                return true;
            }
        }
        for (ServiceDescriptorProto serviceDescriptor : protoFile.getServiceList()) {
            if (serviceDescriptor.getName().equals(sanitizedClassName)) {
                return true;
            }
        }
        for (DescriptorProto typeDescriptor : protoFile.getMessageTypeList()) {
            if (hasConflictingClassName(sanitizedClassName, typeDescriptor)) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return {@code true} if the message, any nested message (recursively), or any enum declared inside it is named
     * {@code sanitizedClassName}.
     */
    private static boolean hasConflictingClassName(String sanitizedClassName, DescriptorProto typeDescriptor) {
        if (typeDescriptor.getName().equals(sanitizedClassName)) {
            return true;
        }
        for (DescriptorProto nestedTypeDescriptor : typeDescriptor.getNestedTypeList()) {
            if (hasConflictingClassName(sanitizedClassName, nestedTypeDescriptor)) {
                return true;
            }
        }
        for (EnumDescriptorProto enumDescriptor : typeDescriptor.getEnumTypeList()) {
            if (enumDescriptor.getName().equals(sanitizedClassName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Uses the proto package as the Java package. If there is no (or an empty) proto package, it falls back to a
     * name derived from the file name.
     * <p>
     * NOTE(review): protobuf-java appears to use the unnamed (default) package in that case. Confirm this fallback
     * still lines up with the file protoc generates when {@code java_multiple_files} is false.
     */
    private static String inferJavaPackageName(@Nullable String protoPackageName, String sanitizedProtoFileName) {
        return isNotNullNorEmpty(protoPackageName) ? protoPackageName : sanitizeClassName(sanitizedProtoFileName);
    }

    /**
     * Returns the base name of a proto file path: the text after the last {@code '/'}, with its last
     * {@code '.'}-extension removed (e.g. {@code "a/b/foo.proto"} becomes {@code "foo"}). A {@code '.'} in a
     * directory name is not treated as an extension.
     *
     * @param v the proto file path as reported by protoc.
     * @return the base name without extension.
     * @throws IllegalArgumentException if the path ends with {@code '/'}.
     */
    private static String sanitizeFileName(final String v) {
        final int slash = v.lastIndexOf('/');
        final int start = slash + 1;
        if (slash != -1 && start >= v.length()) {
            throw new IllegalArgumentException("Illegal file name: " + v);
        }
        final int dot = v.lastIndexOf('.');
        // A '.' before the base name (e.g. "v1.0/foo") belongs to a directory, not to the file's extension.
        final int end = dot >= start ? dot : v.length();
        return v.substring(start, end);
    }

    private static String sanitizeClassName(final String v) {
        return sanitizeIdentifier(v, false);
    }

    /**
     * @return the relative {@code .java} path for {@code className} in {@code packageName}.
     */
    private static String calculateFileName(final String packageName, final String className) {
        return packageName.replace('.', '/') + '/' + className + ".java";
    }
}
